import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.nn.parallel import data_parallel
from torch.utils.data import DataLoader
from typeguard import typechecked

from espnet2.fileio.datadir_writer import DatadirWriter
from espnet2.fileio.npy_scp import NpyScpWriter
from espnet2.torch_utils.device_funcs import to_device
from espnet2.torch_utils.forward_adaptor import ForwardAdaptor
from espnet2.train.abs_espnet_model import AbsESPnetModel


@torch.no_grad()
@typechecked
def collect_stats(
    model: Union[AbsESPnetModel, None],
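    # NOTE: an annotation written as "X and Y" evaluates to Y, so typeguard
    # effectively checks only the Iterable part below; the DataLoader name is
    # kept for documentation purposes.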
    train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    output_dir: Path,
    ngpu: Optional[int],
    log_interval: Optional[int],
    write_collected_feats: bool,
) -> None:
    """Perform on collect_stats mode.

    Running for deriving the shape information from data
    and gathering statistics.
    This method is used before executing train().

    """

    # ``ngpu`` may be passed as None; treat that the same as 0 (CPU only).
    if ngpu is None:
        ngpu = 0

    npy_scp_writers = {}
    for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]):
        if log_interval is None:
            try:
                log_interval = max(len(itr) // 20, 10)
            except TypeError:
                # The iterator does not expose a length (e.g. it wraps an
                # iterable-style dataset), so fall back to a fixed interval.
                log_interval = 100

        # Running per-feature accumulators: the sum, the sum of squares, and
        # the number of accumulated frames.
        sum_dict = defaultdict(lambda: 0)
        sq_dict = defaultdict(lambda: 0)
        count_dict = defaultdict(lambda: 0)

        with DatadirWriter(output_dir / mode) as datadir_writer:
            for iiter, (keys, batch) in enumerate(itr, 1):
                batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")

                # 1. Write shape file
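                #    One "uttid dim1,dim2,..." line per utterance goes into
                #    <output_dir>/<mode>/<name>_shape; padded frames are trimmed
                #    to the true length when "<name>_lengths" is present.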
                for name in batch:
                    if name.endswith("_lengths"):
                        continue
                    for i, (key, data) in enumerate(zip(keys, batch[name])):
                        if f"{name}_lengths" in batch:
                            lg = int(batch[f"{name}_lengths"][i])
                            data = data[:lg]
                        datadir_writer[f"{name}_shape"][key] = ",".join(
                            map(str, data.shape)
                        )

                if model is not None:
                    # 2. Extract feats
                    if ngpu <= 1:
                        data = model.collect_feats(**batch)
                    else:
                        # Note that data_parallel can parallelize only "forward()"
                        data = data_parallel(
                            ForwardAdaptor(model, "collect_feats"),
                            (),
                            range(ngpu),
                            module_kwargs=batch,
                        )
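                    # `data` maps each feature name to a padded tensor and may
                    # also contain "<name>_lengths" entries with the true
                    # per-utterance lengths.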

                    # 3. Calculate sum and square sum
                    for key, v in data.items():
                        for i, (uttid, seq) in enumerate(zip(keys, v.cpu().numpy())):
                            # Truncate zero-padding region
                            if f"{key}_lengths" in data:
                                length = data[f"{key}_lengths"][i]
                                # seq: (Length, Dim, ...)
                                seq = seq[:length]
                            else:
                                # seq: (Dim, ...) -> (1, Dim, ...)
                                seq = seq[None]
                            # Accumulate value, its square, and count
                            sum_dict[key] += seq.sum(0)
                            sq_dict[key] += (seq**2).sum(0)
                            count_dict[key] += len(seq)

                            # 4. [Option] Write the derived features as npy files
                            if write_collected_feats:
                                # Lazily create one NpyScpWriter per
                                # (feature, split) pair on first use
                                if (key, mode) not in npy_scp_writers:
                                    p = output_dir / mode / "collect_feats"
                                    npy_scp_writers[(key, mode)] = NpyScpWriter(
                                        p / f"data_{key}", p / f"{key}.scp"
                                    )
                                # Save array as npy file
                                npy_scp_writers[(key, mode)][uttid] = seq
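                                # (each utterance is saved as its own .npy file
                                #  and referenced from the companion scp file)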

                if iiter % log_interval == 0:
                    logging.info(f"Niter: {iiter}")

        for key in sum_dict:
            np.savez(
                output_dir / mode / f"{key}_stats.npz",
                count=count_dict[key],
                sum=sum_dict[key],
                sum_square=sq_dict[key],
            )
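        # Each <key>_stats.npz holds enough to reconstruct the global statistics
        # later, e.g. mean = sum / count and var = sum_square / count - mean**2.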

        # batch_keys and stats_keys are used by aggregate_stats_dirs.py
        with (output_dir / mode / "batch_keys").open("w", encoding="utf-8") as f:
            f.write(
                "\n".join(filter(lambda x: not x.endswith("_lengths"), batch)) + "\n"
            )
        with (output_dir / mode / "stats_keys").open("w", encoding="utf-8") as f:
            f.write("\n".join(sum_dict) + "\n")
